{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "be34d25b",
      "metadata": {
        "id": "8377c056591f"
      },
      "source": [
        "Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "6130c8e6",
      "metadata": {
        "cellView": "form",
        "id": "ca23c3f523a7"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "880b5dcf",
      "metadata": {
        "id": "u71STQRgnQ3a"
      },
      "source": [
        "# Fine-tune PaliGemma with JAX\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "<td>\n",
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/PaliGemma/[PaliGemma_2]Finetune_with_JAX.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "</td>\n",
        "<td>\n",
        "<a target=\"_blank\" href=\"https://github.com/google-gemini/gemma-cookbook/blob/main/PaliGemma/[PaliGemma_2]Finetune_with_JAX.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "</td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "74dcda33",
      "metadata": {
        "id": "wR53lePHuiP-"
      },
      "source": [
        "This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* can improve a model's performance on specific tasks, or help it adhere to specific output requirements, when instructions alone aren't sufficient and you have a set of examples that demonstrate the outputs you want. Pretrained PaliGemma checkpoints, like the one used here, are designed to be fine-tuned on a downstream task before they produce task-specific results.\n",
        "\n",
        "### What's in this notebook\n",
        "\n",
        "This notebook uses the model reference implementation from [`big_vision`](https://github.com/google-research/big_vision)\n",
        "and shows how to:\n",
        "\n",
        " * Install dependencies and download the PaliGemma model checkpoint and training data\n",
        " * Load the model onto GPU devices\n",
        " * Prepare the model's inputs for training and inference\n",
        " * Fine-tune the model\n",
        " * Inspect the output\n",
        "\n",
        "The training data for this notebook consists of 90 pairs of images and long captions describing them. To make it runnable on a Kaggle GPU runtime, you'll only fine-tune the attention layers of the language model and freeze the other parameters.\n",
        "\n",
        "This example is for learning purposes only. In a real use case, the amount of data, the set of trainable parameters, the number of training steps, the hyperparameters, and the results could be significantly different.\n",
        "\n",
        "### Before you begin\n",
        "\n",
        "Before going through this notebook, you should be familiar with Python code, as well as how large language models (LLMs) are trained. You don't need to be familiar with JAX, but basic knowledge about JAX (or similar technologies such as Keras) is helpful when reading through the example code."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a42e7554",
      "metadata": {
        "id": "6U0QUFveqSP2"
      },
      "source": [
        "## Setup\n",
        "\n",
        "The following sections explain the preliminary steps for getting a notebook to use a PaliGemma model, including model access and configuring the notebook runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "16b96310",
      "metadata": {
        "id": "qRi1rF4MWlQi"
      },
      "source": [
        "### Get access to PaliGemma\n",
        "\n",
        "Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n",
        "\n",
        "1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n",
        "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma-2) and click **Request Access**.\n",
        "1. Complete the consent form and accept the terms and conditions."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ee60a2fe",
      "metadata": {
        "id": "Kp6XQ2hQB8lv"
      },
      "source": [
        "### Select the runtime\n",
        "\n",
        "To complete this tutorial, you'll need to have a Kaggle runtime with sufficient resources to run the PaliGemma model. In this case, you can use a GPU:\n",
        "\n",
        "1. In the upper-right of the Kaggle notebook window, click on the three dots.\n",
        "1. Select **Accelerator**.\n",
        "1. Choose **GPU P100** or **GPU T4 x2** from the available options."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "016fecda",
      "metadata": {
        "id": "rCd__uzW_eK-"
      },
      "source": [
        "### Fetch the `big_vision` repository and install related dependencies\n",
        "\n",
        "Download the `big_vision` repository to your Kaggle notebook from GitHub and install dependencies related to `big_vision` by running the following code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "c92f001e",
      "metadata": {
        "id": "c2eba4d7d2d3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: jax[cuda12] in /opt/conda/lib/python3.10/site-packages (0.4.26)\r\n",
            "Collecting jax[cuda12]\r\n",
            "  Downloading jax-0.4.35-py3-none-any.whl.metadata (22 kB)\r\n",
            "Collecting jaxlib<=0.4.35,>=0.4.34 (from jax[cuda12])\r\n",
            "  Downloading jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)\r\n",
            "Collecting ml-dtypes>=0.4.0 (from jax[cuda12])\r\n",
            "  Downloading ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\r\n",
            "Requirement already satisfied: numpy>=1.24 in /opt/conda/lib/python3.10/site-packages (from jax[cuda12]) (1.26.4)\r\n",
            "Requirement already satisfied: opt-einsum in /opt/conda/lib/python3.10/site-packages (from jax[cuda12]) (3.3.0)\r\n",
            "Requirement already satisfied: scipy>=1.10 in /opt/conda/lib/python3.10/site-packages (from jax[cuda12]) (1.14.1)\r\n",
            "Collecting jaxlib<=0.4.35,>=0.4.34 (from jax[cuda12])\r\n",
            "  Downloading jaxlib-0.4.34-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)\r\n",
            "Collecting jax-cuda12-plugin<=0.4.35,>=0.4.34 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading jax_cuda12_plugin-0.4.35-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.2 kB)\r\n",
            "Collecting jax-cuda12-pjrt==0.4.35 (from jax-cuda12-plugin<=0.4.35,>=0.4.34->jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading jax_cuda12_pjrt-0.4.35-py3-none-manylinux2014_x86_64.whl.metadata (349 bytes)\r\n",
            "\u001b[33mWARNING: jax-cuda12-plugin 0.4.35 does not provide the extra 'with-cuda'\u001b[0m\u001b[33m\r\n",
            "\u001b[0mCollecting nvidia-cublas-cu12>=12.1.3.1 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)\r\n",
            "Collecting nvidia-cuda-cupti-cu12>=12.1.105 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)\r\n",
            "Collecting nvidia-cuda-nvcc-cu12>=12.1.105 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl.metadata (1.5 kB)\r\n",
            "Collecting nvidia-cuda-runtime-cu12>=12.1.105 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)\r\n",
            "Collecting nvidia-cudnn-cu12<10.0,>=9.1 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_cudnn_cu12-9.6.0.74-py3-none-manylinux_2_27_x86_64.whl.metadata (1.6 kB)\r\n",
            "Collecting nvidia-cufft-cu12>=11.0.2.54 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)\r\n",
            "Collecting nvidia-cusolver-cu12>=11.4.5.107 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)\r\n",
            "Collecting nvidia-cusparse-cu12>=12.1.0.106 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)\r\n",
            "Collecting nvidia-nccl-cu12>=2.18.1 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\r\n",
            "Collecting nvidia-nvjitlink-cu12>=12.1.105 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n",
            "  Downloading nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.5 kB)\r\n",
            "Downloading jaxlib-0.4.34-cp310-cp310-manylinux2014_x86_64.whl (86.1 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.1/86.1 MB\u001b[0m \u001b[31m19.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading jax_cuda12_plugin-0.4.35-cp310-cp310-manylinux2014_x86_64.whl (15.5 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.5/15.5 MB\u001b[0m \u001b[31m85.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading jax_cuda12_pjrt-0.4.35-py3-none-manylinux2014_x86_64.whl (100.8 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 MB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m93.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading jax-0.4.35-py3-none-any.whl (2.2 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m67.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (393.1 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m393.1/393.1 MB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.9 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.9/8.9 MB\u001b[0m \u001b[31m59.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (21.2 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.2/21.2 MB\u001b[0m \u001b[31m78.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (897 kB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m897.7/897.7 kB\u001b[0m \u001b[31m42.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_cudnn_cu12-9.6.0.74-py3-none-manylinux_2_27_x86_64.whl (508.1 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m508.1/508.1 MB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (200.2 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.2/200.2 MB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (158.2 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.2/158.2 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (216.6 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m216.6/216.6 MB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_x86_64.whl (199.0 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.0/199.0 MB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (19.7 MB)\r\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.7/19.7 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
            "\u001b[?25hInstalling collected packages: jax-cuda12-pjrt, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvcc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, ml-dtypes, jax-cuda12-plugin, nvidia-cusparse-cu12, nvidia-cufft-cu12, nvidia-cudnn-cu12, jaxlib, nvidia-cusolver-cu12, jax\r\n",
            "  Attempting uninstall: ml-dtypes\r\n",
            "    Found existing installation: ml-dtypes 0.3.2\r\n",
            "    Uninstalling ml-dtypes-0.3.2:\r\n",
            "      Successfully uninstalled ml-dtypes-0.3.2\r\n",
            "  Attempting uninstall: jaxlib\r\n",
            "    Found existing installation: jaxlib 0.4.26.dev20240620\r\n",
            "    Uninstalling jaxlib-0.4.26.dev20240620:\r\n",
            "      Successfully uninstalled jaxlib-0.4.26.dev20240620\r\n",
            "  Attempting uninstall: jax\r\n",
            "    Found existing installation: jax 0.4.26\r\n",
            "    Uninstalling jax-0.4.26:\r\n",
            "      Successfully uninstalled jax-0.4.26\r\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\r\n",
            "tensorflow 2.16.1 requires ml-dtypes~=0.3.1, but you have ml-dtypes 0.5.0 which is incompatible.\u001b[0m\u001b[31m\r\n",
            "\u001b[0mSuccessfully installed jax-0.4.35 jax-cuda12-pjrt-0.4.35 jax-cuda12-plugin-0.4.35 jaxlib-0.4.34 ml-dtypes-0.5.0 nvidia-cublas-cu12-12.6.4.1 nvidia-cuda-cupti-cu12-12.6.80 nvidia-cuda-nvcc-cu12-12.6.85 nvidia-cuda-runtime-cu12-12.6.77 nvidia-cudnn-cu12-9.6.0.74 nvidia-cufft-cu12-11.3.0.4 nvidia-cusolver-cu12-11.7.1.2 nvidia-cusparse-cu12-12.5.4.2 nvidia-nccl-cu12-2.23.4 nvidia-nvjitlink-cu12-12.6.85\r\n"
          ]
        }
      ],
      "source": [
        "!pip install -U \"jax[cuda12]\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "3927a091",
      "metadata": {
        "id": "DfxKb3F839Ks"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "!git clone --quiet --branch=main --depth=1 \\\n",
        "    https://github.com/google-research/big_vision big_vision_repo\n",
        "\n",
        "# Append big_vision code to python import path\n",
        "if \"big_vision_repo\" not in sys.path:\n",
        "  sys.path.append(\"big_vision_repo\")\n",
        "\n",
        "# Install missing dependencies. Assumes a recent JAX with GPU support (installed above).\n",
        "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\""
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a61a030a",
      "metadata": {
        "id": "zDoq0O77GF30"
      },
      "source": [
        "### Import JAX and other dependencies\n",
        "\n",
        "Import JAX and other dependencies required for PaliGemma, like TensorFlow and NumPy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "e15a2524",
      "metadata": {
        "id": "dTfe2k8J4Bw0"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipykernel_24/840491807.py:16: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
            "  from IPython.core.display import display, HTML\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "JAX version:  0.4.35\n",
            "JAX platform: gpu\n",
            "JAX devices:  2\n"
          ]
        }
      ],
      "source": [
        "import base64\n",
        "import functools\n",
        "import html\n",
        "import io\n",
        "import os\n",
        "import warnings\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import ml_collections\n",
        "\n",
        "import tensorflow as tf\n",
        "import sentencepiece\n",
        "\n",
        "from IPython.display import display, HTML\n",
        "from PIL import Image\n",
        "\n",
        "# Import model definition from big_vision\n",
        "from big_vision.models.proj.paligemma import paligemma\n",
        "from big_vision.trainers.proj.paligemma import predict_fns\n",
        "\n",
        "# Import big vision utilities\n",
        "import big_vision.datasets.jsonl\n",
        "import big_vision.utils\n",
        "import big_vision.sharding\n",
        "\n",
        "# Don't let TF use the GPU or TPUs\n",
        "tf.config.set_visible_devices([], \"GPU\")\n",
        "tf.config.set_visible_devices([], \"TPU\")\n",
        "\n",
        "backend = jax.extend.backend.get_backend()\n",
        "print(f\"JAX version:  {jax.__version__}\")\n",
        "print(f\"JAX platform: {backend.platform}\")\n",
        "print(f\"JAX devices:  {jax.device_count()}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "92dbccf9",
      "metadata": {
        "id": "b9kSadtIhjlX"
      },
      "source": [
        "## Download and configure the model\n",
        "\n",
        "In this step, you'll download the model checkpoint and configure it so that you can fine-tune it later on. This step also shows you how to move the model parameters into GPU memory and shard them across the available devices, which helps the model fit on hardware with limited memory."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "923baf02",
      "metadata": {
        "id": "7tvcc0oQHl4v"
      },
      "source": [
        "### Download the model checkpoint\n",
        "\n",
        "PaliGemma includes several model variations. For this tutorial, you'll use the pretrained [PaliGemma 2 3B JAX/Flax checkpoint](https://www.kaggle.com/models/google/paligemma-2/jax/paligemma2-3b-pt-224), which takes 224x224 input images.\n",
        "\n",
        "Download the 16-bit (`.b16.npz`) model checkpoint from Kaggle by running the following code. The same cell also downloads the tokenizer and the `longcap100` training data. This process takes several minutes to complete."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "fde048e2",
      "metadata": {
        "id": "gQNOTfF24AV4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading the checkpoint from Kaggle, this could take a few minutes....\n",
            "Model path: /kaggle/input/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz\n",
            "Downloading the model tokenizer...\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
            "  pid, fd = os.forkpty()\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Copying gs://big_vision/paligemma_tokenizer.model...\r\n",
            "\r\n",
            "Operation completed over 1 objects/4.1 MiB.                                      \r\n",
            "Tokenizer path: ./paligemma_tokenizer.model\n",
            "Downloading the dataset...\n",
            "Data path: ./longcap100\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import kagglehub\n",
        "\n",
        "# Use these for PaliGemma-2 3B 224px²\n",
        "LLM_VARIANT = \"gemma2_2b\"\n",
        "MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n",
        "KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\"  # Path to fetch from Kaggle.\n",
        "\n",
        "# Use these for PaliGemma 1:\n",
        "# LLM_VARIANT = \"gemma_2b\"\n",
        "# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n",
        "# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n",
        "\n",
        "if not os.path.exists(MODEL_PATH):\n",
        "  print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n",
        "  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n",
        "  print(f\"Model path: {MODEL_PATH}\")\n",
        "\n",
        "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n",
        "if not os.path.exists(TOKENIZER_PATH):\n",
        "  print(\"Downloading the model tokenizer...\")\n",
        "  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}\n",
        "  print(f\"Tokenizer path: {TOKENIZER_PATH}\")\n",
        "\n",
        "DATA_DIR = \"./longcap100\"\n",
        "if not os.path.exists(DATA_DIR):\n",
        "  print(\"Downloading the dataset...\")\n",
        "  !gsutil -m -q cp -n -r gs://longcap100/ .\n",
        "  print(f\"Data path: {DATA_DIR}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "dd46593b",
      "metadata": {
        "id": "rv7w-cGuLj5o"
      },
      "source": [
        "### Configure the model\n",
        "\n",
        "Next, configure the model that you're going to fine-tune.\n",
        "\n",
        "For this notebook, the model has to fit on a GPU. Because GPU memory is limited, you have to be mindful of how the model is configured.\n",
        "\n",
        "If you fine-tune every parameter, the model won't fit in the notebook environment. Instead, you'll freeze most of the parameters and fine-tune only a subset of them (in this notebook, the attention layers of the language model). A parameter is *frozen* when it is still used in the forward pass but is no longer updated during training.\n",
        "\n",
        "To configure the model, you need to:\n",
        "\n",
        "* Define the `model_config` as a [`FrozenConfigDict`](https://github.com/google/ml_collections/tree/master#frozenconfigdict), an immutable version of `ConfigDict` that describes the model architecture\n",
        "* Create an instance of the PaliGemma `Model` class from `model_config`\n",
        "* Load the model parameters into RAM\n",
        "* Define a `decode` function to sample outputs from the model\n",
        "\n",
        "The code in the following cell takes about a minute to run."
      ]
    },
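    {
      "cell_type": "markdown",
      "id": "7e1f0a42",
      "metadata": {},
      "source": [
        "The following sketch shows, in plain NumPy, what freezing means during a training step. It is a simplified illustration rather than this notebook's actual update function: `is_trainable`, `sgd_step`, and the toy parameter shapes are hypothetical, and the name-based rule mirrors the `is_trainable_param` mask defined later in this notebook. Frozen parameters are still read in the forward pass, but the update returns them unchanged.\n",
        "\n",
        "```python\n",
        "# Hypothetical sketch: plain SGD over a flat dict of named parameters,\n",
        "# where a name-based rule decides which entries get updated.\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def is_trainable(name):\n",
        "  # Same rule as the notebook: only the LLM attention weights are trained.\n",
        "  return name.startswith('llm/layers/attn/')\n",
        "\n",
        "\n",
        "def sgd_step(params, grads, learning_rate=0.1):\n",
        "  # Trainable entries take a gradient step; frozen entries are returned as-is.\n",
        "  return {\n",
        "      name: p - learning_rate * grads[name] if is_trainable(name) else p\n",
        "      for name, p in params.items()\n",
        "  }\n",
        "\n",
        "\n",
        "# Toy parameters with made-up shapes; the names mimic the real param tree.\n",
        "params = {\n",
        "    'llm/layers/attn/q_einsum/w': np.ones((2, 3), np.float32),\n",
        "    'llm/layers/mlp/linear': np.ones((3, 2), np.float16),\n",
        "    'img/head/kernel': np.ones((2, 2), np.float16),\n",
        "}\n",
        "grads = {name: np.ones_like(p) for name, p in params.items()}\n",
        "\n",
        "new_params = sgd_step(params, grads)\n",
        "for name, p in new_params.items():\n",
        "  print(name, 'updated' if not np.array_equal(p, params[name]) else 'frozen')\n",
        "```"
      ]
    },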
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "30747284",
      "metadata": {
        "id": "1aghcULcEdtv"
      },
      "outputs": [],
      "source": [
        "# Define model\n",
        "\n",
        "# IMPORTANT: Gemma 2 has a \"final_logits_softcap\" property; set it to 0.0\n",
        "# for better transfer results.\n",
        "model_config = ml_collections.FrozenConfigDict({\n",
        "    \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
        "    \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n",
        "})\n",
        "model = paligemma.Model(**model_config)\n",
        "tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)\n",
        "\n",
        "# Load params. This can take up to 1 minute on a T4 GPU.\n",
        "params = paligemma.load(None, MODEL_PATH, model_config)\n",
        "\n",
        "# Define `decode` function to sample outputs from the model.\n",
        "decode_fn = predict_fns.get_all(model)['decode']\n",
        "decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "fdfd4faf",
      "metadata": {
        "id": "uidBwmb8LwZ5"
      },
      "source": [
        "### Move model parameters into GPU/TPU memory\n",
        "\n",
        "Now you need to move the model parameters into GPU memory. First, define how the parameters are sharded across the available GPUs, then load them. Here, you'll load the parameters sequentially, one array at a time. This is slower than loading them all at once, but loading them all at once would require more RAM than this notebook has available.\n",
        "\n",
        "Finally, print all of the parameters to see the dtype each one is cast to. Frozen parameters are kept as `float16`, while the trainable parameters are cast to `float32`. In the list, you'll see that most of the parameters are frozen and remain `float16`."
      ]
    },
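    {
      "cell_type": "markdown",
      "id": "9b2d4c61",
      "metadata": {},
      "source": [
        "To see why freezing most parameters matters, the sketch below estimates the memory taken by the language-model parameters alone, using the shapes listed in the output of the next cell and the dtypes described above (`float32` for trainable, `float16` for frozen). It is a rough estimate under stated assumptions: it ignores the image encoder, gradients, optimizer state, and activations, and it assumes every array is split evenly across devices. `LLM_SHAPES` and `summarize` are illustrative names, not part of `big_vision`.\n",
        "\n",
        "```python\n",
        "# Back-of-the-envelope estimate; shapes copied from the parameter listing.\n",
        "import math\n",
        "\n",
        "LLM_SHAPES = {\n",
        "    'llm/embedder/input_embedding': (257152, 2304),\n",
        "    'llm/final_norm/scale': (2304,),\n",
        "    'llm/layers/attn/attn_vec_einsum/w': (26, 8, 256, 2304),\n",
        "    'llm/layers/attn/kv_einsum/w': (26, 2, 4, 2304, 256),\n",
        "    'llm/layers/attn/q_einsum/w': (26, 8, 2304, 256),\n",
        "    'llm/layers/mlp/gating_einsum': (26, 2, 2304, 9216),\n",
        "    'llm/layers/mlp/linear': (26, 9216, 2304),\n",
        "    'llm/layers/post_attention_norm/scale': (26, 2304),\n",
        "    'llm/layers/post_ffw_norm/scale': (26, 2304),\n",
        "    'llm/layers/pre_attention_norm/scale': (26, 2304),\n",
        "    'llm/layers/pre_ffw_norm/scale': (26, 2304),\n",
        "}\n",
        "\n",
        "\n",
        "def is_trainable(name):\n",
        "  # Same rule as the notebook: only the LLM attention weights are trained.\n",
        "  return name.startswith('llm/layers/attn/')\n",
        "\n",
        "\n",
        "def summarize(shapes, num_devices=1):\n",
        "  trainable = sum(math.prod(s) for n, s in shapes.items() if is_trainable(n))\n",
        "  frozen = sum(math.prod(s) for n, s in shapes.items() if not is_trainable(n))\n",
        "  # float32 uses 4 bytes per value, float16 uses 2 bytes per value.\n",
        "  total_bytes = 4 * trainable + 2 * frozen\n",
        "  return {\n",
        "      'trainable': trainable,\n",
        "      'frozen': frozen,\n",
        "      'trainable_fraction': trainable / (trainable + frozen),\n",
        "      # Assumes each array is split evenly across devices (FSDP-style).\n",
        "      'gib_per_device': total_bytes / num_devices / 2**30,\n",
        "  }\n",
        "\n",
        "\n",
        "stats = summarize(LLM_SHAPES, num_devices=2)\n",
        "print('trainable params:', format(stats['trainable'], ','))\n",
        "print('frozen params:', format(stats['frozen'], ','))\n",
        "print('trainable fraction:', round(stats['trainable_fraction'], 3))\n",
        "print('GiB per device:', round(stats['gib_per_device'], 2))\n",
        "```"
      ]
    },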
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "19f25fb1",
      "metadata": {
        "id": "RWOdf_fw2SAO"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            " == Model params == \n",
            "img/Transformer/encoder_norm/bias                                                (1152,)                float16\n",
            "img/Transformer/encoder_norm/scale                                               (1152,)                float16\n",
            "img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16\n",
            "img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16\n",
            "img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16\n",
            "img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16\n",
            "img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16\n",
            "img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16\n",
            "img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (27, 1152)             float16\n",
            "img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel                           (27, 4304, 1152)       float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias             (27, 16, 72)           float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel           (27, 1152, 16, 72)     float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias             (27, 1152)             float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel           (27, 16, 72, 1152)     float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias           (27, 16, 72)           float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel         (27, 1152, 16, 72)     float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias           (27, 16, 72)           float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel         (27, 1152, 16, 72)     float16\n",
            "img/embedding/bias                                                               (1152,)                float16\n",
            "img/embedding/kernel                                                             (14, 14, 3, 1152)      float16\n",
            "img/head/bias                                                                    (2304,)                float16\n",
            "img/head/kernel                                                                  (1152, 2304)           float16\n",
            "img/pos_embedding                                                                (1, 256, 1152)         float16\n",
            "llm/embedder/input_embedding                                                     (257152, 2304)         float16\n",
            "llm/final_norm/scale                                                             (2304,)                float16\n",
            "llm/layers/attn/attn_vec_einsum/w                                                (26, 8, 256, 2304)     float32\n",
            "llm/layers/attn/kv_einsum/w                                                      (26, 2, 4, 2304, 256)  float32\n",
            "llm/layers/attn/q_einsum/w                                                       (26, 8, 2304, 256)     float32\n",
            "llm/layers/mlp/gating_einsum                                                     (26, 2, 2304, 9216)    float16\n",
            "llm/layers/mlp/linear                                                            (26, 9216, 2304)       float16\n",
            "llm/layers/post_attention_norm/scale                                             (26, 2304)             float16\n",
            "llm/layers/post_ffw_norm/scale                                                   (26, 2304)             float16\n",
            "llm/layers/pre_attention_norm/scale                                              (26, 2304)             float16\n",
            "llm/layers/pre_ffw_norm/scale                                                    (26, 2304)             float16\n"
          ]
        }
      ],
      "source": [
        "# Create a pytree mask of the trainable params.\n",
        "def is_trainable_param(name, param):  # pylint: disable=unused-argument\n",
        "  if name.startswith(\"llm/layers/attn/\"):  return True\n",
        "  if name.startswith(\"llm/\"):              return False\n",
        "  if name.startswith(\"img/\"):              return False\n",
        "  raise ValueError(f\"Unexpected param name {name}\")\n",
        "trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)\n",
        "\n",
        "# If more than one device is available (e.g. multiple GPUs), the parameters\n",
        "# can be sharded across them to reduce HBM usage per device.\n",
        "mesh = jax.sharding.Mesh(jax.devices(), (\"data\",))\n",
        "\n",
        "data_sharding = jax.sharding.NamedSharding(\n",
        "    mesh, jax.sharding.PartitionSpec(\"data\"))\n",
        "\n",
        "params_sharding = big_vision.sharding.infer_sharding(\n",
        "    params, strategy=[('.*', 'fsdp(axis=\"data\")')], mesh=mesh)\n",
        "\n",
        "# Some donated buffers are expected to be unusable; silence that warning.\n",
        "warnings.filterwarnings(\n",
        "    \"ignore\", message=\"Some donated buffers were not usable\")\n",
        "\n",
        "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n",
        "def maybe_cast_to_f32(params, trainable):\n",
        "  # Cast trainable params to float32 and all others to float16, since some GPUs don't support bf16.\n",
        "  return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n",
        "                      if m else p.astype(jnp.float16),\n",
        "                      params, trainable)\n",
        "\n",
        "# Loading all params simultaneously, although faster and more succinct,\n",
        "# requires more RAM than the T4 Colab runtimes have by default.\n",
        "# Instead, load and cast them one param at a time.\n",
        "params, treedef = jax.tree.flatten(params)\n",
        "sharding_leaves = jax.tree.leaves(params_sharding)\n",
        "trainable_leaves = jax.tree.leaves(trainable_mask)\n",
        "for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):\n",
        "  params[idx] = big_vision.utils.reshard(params[idx], sharding)\n",
        "  params[idx] = maybe_cast_to_f32(params[idx], trainable)\n",
        "  params[idx].block_until_ready()\n",
        "params = jax.tree.unflatten(treedef, params)\n",
        "\n",
        "# Print params to show what the model is made of.\n",
        "def parameter_overview(params):\n",
        "  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:\n",
        "    print(f\"{path:80s} {str(arr.shape):22s} {arr.dtype}\")\n",
        "\n",
        "print(\" == Model params == \")\n",
        "parameter_overview(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4bde55a0",
      "metadata": {
        "id": "iD_9XXQkn1Mv"
      },
      "source": [
        "## Prepare to tune the model\n",
        "\n",
        "Now that your model is configured, you can prepare it for tuning. In this step, you'll create your model's inputs as well as the training and validation iterators, view the training examples, and define the training and validation loops."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3ef1ef32",
      "metadata": {
        "id": "83ZcnbddJKdx"
      },
      "source": [
        "### Create model inputs\n",
        "\n",
        "The model checkpoint you're using has already been trained to handle images of various aspect ratios resized to 224x224 pixels, as well as tokenized text.\n",
        "\n",
        "The code below defines three functions that you'll use in the next step to create the model's inputs:\n",
        "\n",
        "* **`preprocess_image`:** Prepares the image data for the model. Pre-processing turns a single-channel (greyscale) image into a three-channel image, removes the alpha channel if there is one, resizes the image to the size required by the model for image inputs (224x224 pixels), and scales the pixel values from [0, 255] to [-1, 1].\n",
        "* **`preprocess_tokens`:** Tokenizes the prefix and the optional suffix and builds three masks: `mask_ar` marks whether each token uses full attention (prefix) or causal attention (suffix), `mask_loss` marks whether a token contributes to the loss (suffix tokens only), and `mask_input` marks real tokens versus padding. If `seqlen` is given, all sequences are truncated or padded to that length. These masks are used later in the training step and the evaluation loop.\n",
        "* **`postprocess_tokens`:** Removes the end-of-sequence (EOS) token and any tokens after it, and returns the decoded text.\n"
      ]
    },
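    {
      "cell_type": "markdown",
      "id": "a7c41e09",
      "metadata": {},
      "source": [
        "To see what the three masks look like, here is a minimal sketch that mirrors the masking and padding logic of `preprocess_tokens` without a real tokenizer. It assumes the prefix and suffix have already been encoded into token IDs; `toy_preprocess_tokens` and the IDs used are hypothetical and only stand in for the `tokenizer.encode` calls.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def toy_preprocess_tokens(prefix_ids, suffix_ids=None, seqlen=None):\n",
        "  # Hypothetical helper: takes already-encoded token IDs instead of text.\n",
        "  tokens = list(prefix_ids)\n",
        "  mask_ar = [0] * len(tokens)    # Prefix: full attention.\n",
        "  mask_loss = [0] * len(tokens)  # Prefix: not part of the loss.\n",
        "  if suffix_ids:\n",
        "    tokens += list(suffix_ids)\n",
        "    mask_ar += [1] * len(suffix_ids)    # Suffix: causal attention.\n",
        "    mask_loss += [1] * len(suffix_ids)  # Suffix: part of the loss.\n",
        "  mask_input = [1] * len(tokens)  # 1 for real tokens, 0 for padding.\n",
        "  if seqlen:\n",
        "    padding = [0] * max(0, seqlen - len(tokens))\n",
        "    tokens = tokens[:seqlen] + padding\n",
        "    mask_ar = mask_ar[:seqlen] + padding\n",
        "    mask_loss = mask_loss[:seqlen] + padding\n",
        "    mask_input = mask_input[:seqlen] + padding\n",
        "  return tuple(np.array(x) for x in (tokens, mask_ar, mask_loss, mask_input))\n",
        "\n",
        "# A 3-token prefix (e.g. BOS, prompt, separator) and a 2-token suffix\n",
        "# (e.g. one word and EOS), padded to 8 positions.\n",
        "tokens, mask_ar, mask_loss, mask_input = toy_preprocess_tokens(\n",
        "    [2, 10, 11], [20, 1], seqlen=8)\n",
        "print(mask_ar)     # [0 0 0 1 1 0 0 0]\n",
        "print(mask_loss)   # [0 0 0 1 1 0 0 0]\n",
        "print(mask_input)  # [1 1 1 1 1 0 0 0]\n",
        "```\n",
        "\n",
        "Padding positions are 0 in all three masks, and `mask_loss` in particular guarantees that they never contribute to the loss."
      ]
    },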
    {
      "cell_type": "code",
      "execution_count": 7,
      "id": "aea6b72a",
      "metadata": {
        "id": "8SRW0NuU4UcW"
      },
      "outputs": [],
      "source": [
        "def preprocess_image(image, size=224):\n",
        "  # Model has been trained to handle images of different aspect ratios\n",
        "  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize\n",
        "  # options are helpful to improve quality in some tasks.\n",
        "  image = np.asarray(image)\n",
        "  if image.ndim == 2:  # Greyscale image without a channel axis: make it RGB.\n",
        "    image = np.stack((image,)*3, axis=-1)\n",
        "  image = image[..., :3]  # Remove alpha channel, if any.\n",
        "  assert image.shape[-1] == 3\n",
        "\n",
        "  image = tf.constant(image)\n",
        "  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n",
        "  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]\n",
        "\n",
        "def preprocess_tokens(prefix, suffix=None, seqlen=None):\n",
        "  # Model has been trained to handle tokenized text composed of a prefix with\n",
        "  # full attention and a suffix with causal attention.\n",
        "  separator = \"\\n\"\n",
        "  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)\n",
        "  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.\n",
        "  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.\n",
        "\n",
        "  if suffix:\n",
        "    suffix = tokenizer.encode(suffix, add_eos=True)\n",
        "    tokens += suffix\n",
        "    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.\n",
        "    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.\n",
        "\n",
        "  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.\n",
        "  if seqlen:\n",
        "    padding = [0] * max(0, seqlen - len(tokens))\n",
        "    tokens = tokens[:seqlen] + padding\n",
        "    mask_ar = mask_ar[:seqlen] + padding\n",
        "    mask_loss = mask_loss[:seqlen] + padding\n",
        "    mask_input = mask_input[:seqlen] + padding\n",
        "\n",
        "  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))\n",
        "\n",
        "def postprocess_tokens(tokens):\n",
        "  tokens = tokens.tolist()  # np.array to list[int]\n",
        "  try:  # Remove tokens at and after EOS if any.\n",
        "    eos_pos = tokens.index(tokenizer.eos_id())\n",
        "    tokens = tokens[:eos_pos]\n",
        "  except ValueError:\n",
        "    pass\n",
        "  return tokenizer.decode(tokens)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "672eed66",
      "metadata": {
        "id": "ovgWBgdHJZq3"
      },
      "source": [
        "### Create the training and validation iterators\n",
        "\n",
        "Create two iterators:\n",
        "\n",
        "*   A **training iterator** that yields pre-processed examples one at a time, shuffling the data and repeating it indefinitely so that training can run for many epochs\n",
        "    *   Because each example is pre-processed only when it's requested, the data doesn't all have to be processed up front\n",
        "*   A **validation iterator** that makes a single, ordered pass over the validation dataset so that you can check how well the tuned model's predictions match the reference descriptions"
      ]
    },
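    {
      "cell_type": "markdown",
      "id": "b52e8d1f",
      "metadata": {},
      "source": [
        "The difference between the two iterators can be sketched with plain Python generators. This is a simplified analogue, not the `big_vision` or `tf.data` implementation: the hypothetical `toy_train_iterator` reshuffles and repeats the data forever (roughly what `shuffle(...).repeat()` does for a small dataset), while the hypothetical `toy_validation_iterator` makes one ordered pass and then stops.\n",
        "\n",
        "```python\n",
        "import itertools\n",
        "import random\n",
        "\n",
        "def toy_train_iterator(examples, seed=0):\n",
        "  # Never ending: shuffle, yield every example once, then reshuffle.\n",
        "  rng = random.Random(seed)\n",
        "  while True:\n",
        "    order = list(examples)\n",
        "    rng.shuffle(order)\n",
        "    yield from order\n",
        "\n",
        "def toy_validation_iterator(examples):\n",
        "  # A single pass in the original order.\n",
        "  yield from examples\n",
        "\n",
        "data = ['a', 'b', 'c']\n",
        "# Take two epochs' worth of training examples: each example appears twice.\n",
        "print(sorted(itertools.islice(toy_train_iterator(data), 6)))\n",
        "print(list(toy_validation_iterator(data)))  # ['a', 'b', 'c']\n",
        "```\n",
        "\n",
        "Because the training iterator never ends, the training loop decides when to stop (after a fixed number of steps), whereas the validation iterator's `StopIteration` tells the evaluation loop that it has seen every example."
      ]
    },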
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "fc220ff0",
      "metadata": {
        "id": "whzWOojGOtzi"
      },
      "outputs": [],
      "source": [
        "SEQLEN = 128\n",
        "\n",
        "train_dataset = big_vision.datasets.jsonl.DataSource(\n",
        "    os.path.join(DATA_DIR, \"data_train90.jsonl\"),\n",
        "    fopen_keys={\"image\": DATA_DIR})\n",
        "\n",
        "val_dataset = big_vision.datasets.jsonl.DataSource(\n",
        "    os.path.join(DATA_DIR, \"data_val10.jsonl\"),\n",
        "    fopen_keys={\"image\": DATA_DIR})\n",
        "\n",
        "\n",
        "def train_data_iterator():\n",
        "  \"\"\"Never ending iterator over training examples.\"\"\"\n",
        "  # Shuffle examples and repeat so one can train for many epochs.\n",
        "  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()\n",
        "  for example in dataset.as_numpy_iterator():\n",
        "    image = Image.open(io.BytesIO(example[\"image\"]))\n",
        "    image = preprocess_image(image)\n",
        "\n",
        "    prefix = \"caption en\"  # Could also be a different prefix per example.\n",
        "    suffix = example[\"suffix\"].decode().lower()\n",
        "    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)\n",
        "\n",
        "    yield {\n",
        "        \"image\": np.asarray(image),\n",
        "        \"text\": np.asarray(tokens),\n",
        "        \"mask_ar\": np.asarray(mask_ar),\n",
        "        \"mask_loss\": np.asarray(mask_loss),\n",
        "    }\n",
        "\n",
        "\n",
        "def validation_data_iterator():\n",
        "  \"\"\"Single iterator over validation examples.\"\"\"\n",
        "  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():\n",
        "    image = Image.open(io.BytesIO(example[\"image\"]))\n",
        "    image = preprocess_image(image)\n",
        "\n",
        "    prefix = \"caption en\"  # Could also be a different prefix per example.\n",
        "    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)\n",
        "\n",
        "    yield {\n",
        "        \"image\": np.asarray(image),\n",
        "        \"text\": np.asarray(tokens),\n",
        "        \"mask_ar\": np.asarray(mask_ar),\n",
        "        \"mask_input\": np.asarray(mask_input),\n",
        "    }\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0849f8a1",
      "metadata": {
        "id": "84olaM5dCiAl"
      },
      "source": [
        "### View training examples\n",
        "\n",
        "In this notebook, the training data contains 90 images that are paired with long descriptions of what's depicted in the image.\n",
        "\n",
        "**Note:** Training datasets for practical use cases should contain many more images. This notebook limits the number of data points so that you can train the model in a reasonable amount of time for this example.\n",
        "\n",
        "The code below displays eight randomly selected images and their descriptions from the training dataset so that you can see what the images and descriptions your model is trained on look like. Each image is displayed as a 128x128 pixel JPEG, with its description printed to the right."
      ]
    },
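    {
      "cell_type": "markdown",
      "id": "c93f0a6d",
      "metadata": {},
      "source": [
        "To display an image, `render_example` first reverses the normalization applied by `preprocess_image`. A minimal NumPy sketch of that round trip on three sample pixel values:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)\n",
        "normalized = pixels / 127.5 - 1.0  # [0, 255] -> [-1, 1], as in preprocess_image.\n",
        "restored = ((normalized + 1) / 2 * 255).astype(np.uint8)  # [-1, 1] -> [0, 255].\n",
        "print(normalized.tolist())  # [-1.0, 0.0, 1.0]\n",
        "print(restored.tolist())    # [0, 127, 255]\n",
        "```\n",
        "\n",
        "The midpoint comes back as 127 rather than 127.5 because `astype(np.uint8)` truncates the fractional part, so the displayed image can differ very slightly from the pre-processed one."
      ]
    },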
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "55a7464e",
      "metadata": {
        "id": "BzJfb5t0nsLq"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Training examples\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a table topped with a variety of items, including a wooden box, a brush, a bowl, a jar, and a towel. the table is black, and the items are arranged neatly. the brush is made of wood, and the bowl is made of wood. the knife is made of wood, and the towel is striped. the jar is made of metal, and the lid is on the jar.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a martini glass sits on a bar, its contents neatly arranged. the glass sits on a black coaster, reflecting the lights of the city lights in the background. the bar is illuminated by a warm glow, casting long shadows on the wall. the glass on the coaster holds a lemon slice, a testament to the refreshing nature of the drink. the overall atmosphere is relaxed and inviting.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a plate of spaghetti with vegetables, including green leaves, green beans and bacon crumbs on the plate. the plate is white and sits on a black table. the spaghetti is yellow.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a blue bicycle is parked next to awooden fence. the bike has a black seat, a black kickstand, and a black tire. the fence is brown and the grass is green. there is a small green bush and a small green tree in the background. the bike has a light on the front and a light on the back. the front tire of the bike is on the ground and the back tire is on the fence. the bike has a black pedal and a black pedal on the bike. the bike has a black seat and a black seat on the bike. the bike has a black kickstand and a</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a large sign in the shape of a crown stands proudly in the center of a city square. the sign is illuminated by the reflection of the sun on the water, creating a vibrant display. a tall building casts long shadows on the ground, while a flag on top of a building waves proudly. people stroll along the sidewalk. the sky is clear and blue, with fluffy white clouds drifting above. the reflection of the city in the water is a mirror image of the city itself, showcasing the beauty and diversity of this urban landscape.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a leopard sits majestically on a tree branch, its eyes open and its mouth closed. the leopard&#x27;s coat is adorned with intricate black spots, and its eyes are a vibrant blue. the tree behind the animal is tall and slender, its branches reaching out like a welcoming embrace. the leopard&#x27;s whiskers are white. </p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a man stands on a red track, his leg raised high, his shoe firmly planted on the ground. the track is lined with white lines, and the grass is green. the man wears red shorts and grey and white shoes, and his socks are black. the man is running towards the finish line.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a group of people walk down a street. a black and white sign hangs from a building, while a brown sign with gold lettering advertises a business. a woman with a pink hat and a woman with a black backpack walk side by side, their backs facing the camera. a black and white sign on a pole and a black and white sign on a building are also visible. a woman with a blue jacket and a black backpack walk on the street.</p>\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "def render_inline(image, resize=(128, 128)):\n",
        "  \"\"\"Convert image into inline html.\"\"\"\n",
        "  image = Image.fromarray(image)\n",
        "  image = image.resize(resize)  # PIL's resize returns a new image.\n",
        "  with io.BytesIO() as buffer:\n",
        "    image.save(buffer, format='jpeg')\n",
        "    image_b64 = str(base64.b64encode(buffer.getvalue()), \"utf-8\")\n",
        "    return f\"data:image/jpeg;base64,{image_b64}\"\n",
        "\n",
        "def render_example(image, caption):\n",
        "  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]\n",
        "  return f\"\"\"\n",
        "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
        "        <img style=\"width:128px; height:128px;\" src=\"{render_inline(image, resize=(128, 128))}\" />\n",
        "        <p style=\"width:256px; margin:10px; font-size:small;\">{html.escape(caption)}</p>\n",
        "    </div>\n",
        "    \"\"\"\n",
        "\n",
        "html_out = \"\"\n",
        "for idx, example in zip(range(8), train_data_iterator()):\n",
        "  caption = postprocess_tokens(example[\"text\"])  # detokenize model input.\n",
        "  caption = caption[len(\"caption en\\n\"):]        # strip prefix\n",
        "  html_out += render_example(example[\"image\"], caption)\n",
        "\n",
        "print(\"Training examples\")\n",
        "display(HTML(html_out))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5a55c1c1",
      "metadata": {
        "id": "N2BwpXkfI8OT"
      },
      "source": [
        "### Define the training and evaluation loops\n",
        "\n",
        "Define the training loop to train the model on the provided dataset, and the evaluation loop to run the model on all of the examples in the validation dataset and collect its predictions.\n",
        "\n",
        "#### Defining the training loop\n",
        "\n",
        "The `update_fn` function defines the training step. During the training step, the loss per example is calculated and stochastic gradient descent (SGD) is applied to the trainable parameters.\n",
        "\n",
        "Recall that the `preprocess_tokens` function returns a `mask_loss` mask. Here, `mask_loss` is used to exclude prefix and padding tokens from the loss; without it, the loss would also include the model's predictions of the prompt and the padding, which skews it. Because each example has a different number of suffix tokens, the loss of each example is normalized by dividing its summed token loss by its number of loss tokens. The batch loss is then the mean of these per-example losses.\n",
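        "\n",
        "The masked, normalized loss can be sketched in NumPy on a toy batch. This is an illustration, not the notebook's code: `masked_example_loss` is a hypothetical helper, and it selects each target token's log-probability with `np.take_along_axis` rather than with the one-hot product used in `update_fn` (the two give the same values).\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def masked_example_loss(logp, targets, mask_loss):\n",
        "  # logp: [batch, seq_len, vocab] log-probabilities.\n",
        "  # targets: [batch, seq_len] next-token IDs.\n",
        "  # mask_loss: [batch, seq_len], 1 where the token is part of the loss.\n",
        "  target_logp = np.take_along_axis(logp, targets[..., None], axis=-1)[..., 0]\n",
        "  summed = -np.sum(target_logp * mask_loss, axis=-1)  # Sum across seq_len.\n",
        "  return summed / np.clip(np.sum(mask_loss, axis=-1), 1, None)  # Per-token mean.\n",
        "\n",
        "# Uniform predictions over a 4-token vocabulary: every token costs log(4).\n",
        "logp = np.full((2, 3, 4), np.log(0.25))\n",
        "targets = np.array([[1, 2, 3], [0, 0, 0]])\n",
        "mask_loss = np.array([[0, 1, 1], [0, 0, 1]])  # 2 loss tokens vs. 1 loss token.\n",
        "example_loss = masked_example_loss(logp, targets, mask_loss)\n",
        "batch_loss = example_loss.mean()\n",
        "```\n",
        "\n",
        "Both examples end up with the same loss, log(4) (about 1.386), even though one has twice as many loss tokens as the other; without the division, the first example would weigh twice as much in the batch loss.\n",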
        "\n",
        "Finally, the training step applies a plain SGD update (`param - learning_rate * gradient`) to the parameters selected by `trainable_mask`; all other parameters are left unchanged.\n",
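        "\n",
        "The masked SGD update can be sketched on a flat dictionary of NumPy arrays. This is a simplification: `update_fn` maps `apply_grad` over the full parameter pytree with `jax.tree_util.tree_map`, whereas the hypothetical `masked_sgd` helper and the parameter names below use a plain dict.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def masked_sgd(params, grads, trainable, learning_rate):\n",
        "  # Update trainable params with plain SGD; return frozen params unchanged.\n",
        "  return {name: p - learning_rate * grads[name] if trainable[name] else p\n",
        "          for name, p in params.items()}\n",
        "\n",
        "params = {'attn': np.array([1.0, 2.0]), 'mlp': np.array([3.0])}\n",
        "grads = {'attn': np.array([10.0, 10.0]), 'mlp': np.array([10.0])}\n",
        "trainable = {'attn': True, 'mlp': False}\n",
        "new_params = masked_sgd(params, grads, trainable, learning_rate=0.01)\n",
        "# new_params['attn'] is about [0.9, 1.9]; new_params['mlp'] is still [3.0].\n",
        "```\n",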
        "\n",
        "#### Defining the evaluation loop\n",
        "\n",
        "The `make_predictions` function is your evaluation loop. The evaluation loop is fairly straightforward, with one notable detail: the number of validation examples isn't necessarily a multiple of the batch size, so the last batch can be incomplete. In that case, the loop pads the batch by repeating its last example so that every batch has the same size.\n",
        "\n",
        "To make sure that your evaluation loop only returns predictions for actual examples, each example carries a `_mask` flag (`True` for real examples, `False` for padding), and the padded examples are dropped from the output."
      ]
    },
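    {
      "cell_type": "markdown",
      "id": "f1d6b3a8",
      "metadata": {},
      "source": [
        "The padding and masking logic of `make_predictions` can be sketched in isolation. The hypothetical `pad_batch` helper below marks real examples with `_mask=True`, repeats the last example with `_mask=False` until the batch is full, and stacks each key into one array, similar to what the evaluation loop does before calling `decode`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def pad_batch(examples, batch_size):\n",
        "  # Assumes at least one example, as make_predictions returns early otherwise.\n",
        "  examples = [dict(ex, _mask=np.array(True)) for ex in examples]\n",
        "  while len(examples) % batch_size:\n",
        "    examples.append(dict(examples[-1], _mask=np.array(False)))  # Padding.\n",
        "  # Stack each key across examples into a single array.\n",
        "  return {key: np.stack([ex[key] for ex in examples]) for key in examples[0]}\n",
        "\n",
        "# Three real examples padded to a batch of four.\n",
        "batch = pad_batch([{'text': np.array([1, 2])}, {'text': np.array([3, 4])},\n",
        "                   {'text': np.array([5, 6])}], batch_size=4)\n",
        "print(batch['_mask'].tolist())                # [True, True, True, False]\n",
        "print(batch['text'][batch['_mask']].tolist())  # [[1, 2], [3, 4], [5, 6]]\n",
        "```\n",
        "\n",
        "Indexing with the boolean `_mask` array drops the padded rows, which is how `make_predictions` discards predictions for padding examples."
      ]
    },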
    {
      "cell_type": "code",
      "execution_count": 10,
      "id": "ff9e0a81",
      "metadata": {
        "id": "dwUV_imW3WQJ"
      },
      "outputs": [],
      "source": [
        "# The main update_fn using a simple stochastic gradient descent (SGD).\n",
        "@functools.partial(jax.jit, donate_argnums=(0,))\n",
        "def update_fn(params, batch, learning_rate):\n",
        "  imgs, txts, mask_ar = batch[\"image\"], batch[\"text\"], batch[\"mask_ar\"]\n",
        "\n",
        "  def loss_fn(params):\n",
        "    text_logits, _ = model.apply({\"params\": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)\n",
        "    logp = jax.nn.log_softmax(text_logits, axis=-1)\n",
        "\n",
        "    # The model takes as input txts[:, :-1] but the loss is defined as predicting\n",
        "    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens\n",
        "    # are part of the loss (e.g. prefix and padded tokens are not included).\n",
        "    mask_loss = batch[\"mask_loss\"][:, 1:]\n",
        "    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n",
        "\n",
        "    # Compute the loss per example, i.e. the mean negative log-likelihood of its\n",
        "    # loss tokens. Since each example has a different number of tokens, normalize it.\n",
        "    token_pplx = jnp.sum(logp * targets, axis=-1)  # log-prob of each target token.\n",
        "    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.\n",
        "    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # divide by num of loss tokens.\n",
        "\n",
        "    # batch_loss: mean of per example loss.\n",
        "    return jnp.mean(example_loss)\n",
        "\n",
        "  loss, grads = jax.value_and_grad(loss_fn)(params)\n",
        "\n",
        "  # Apply gradients to trainable params using SGD.\n",
        "  def apply_grad(param, gradient, trainable):\n",
        "    if not trainable: return param\n",
        "    return param - learning_rate * gradient\n",
        "\n",
        "  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)\n",
        "\n",
        "  return params, loss\n",
        "\n",
        "# Evaluation/inference loop.\n",
        "def make_predictions(data_iterator, *, num_examples=None,\n",
        "                     batch_size=4, seqlen=SEQLEN, sampler=\"greedy\"):\n",
        "  outputs = []\n",
        "  while True:\n",
        "    # Construct a list of examples in the batch.\n",
        "    examples = []\n",
        "    try:\n",
        "      for _ in range(batch_size):\n",
        "        examples.append(next(data_iterator))\n",
        "        examples[-1][\"_mask\"] = np.array(True)  # Indicates true example.\n",
        "    except StopIteration:\n",
        "      if len(examples) == 0:\n",
        "        return outputs\n",
        "\n",
        "    # Not enough examples to complete a batch. Pad by repeating last example.\n",
        "    while len(examples) % batch_size:\n",
        "      examples.append(dict(examples[-1]))\n",
        "      examples[-1][\"_mask\"] = np.array(False)  # Indicates padding example.\n",
        "\n",
        "    # Convert list of examples into a dict of np.arrays and load onto devices.\n",
        "    batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n",
        "    batch = big_vision.utils.reshard(batch, data_sharding)\n",
        "\n",
        "    # Make model predictions\n",
        "    tokens = decode({\"params\": params}, batch=batch,\n",
        "                    max_decode_len=seqlen, sampler=sampler)\n",
        "\n",
        "    # Fetch model predictions from device and detokenize.\n",
        "    tokens, mask = jax.device_get((tokens, batch[\"_mask\"]))\n",
        "    tokens = tokens[mask]  # remove padding examples.\n",
        "    responses = [postprocess_tokens(t) for t in tokens]\n",
        "\n",
        "    # Collect (image, response) pairs for the real examples.\n",
        "    for example, response in zip(examples, responses):\n",
        "      outputs.append((example[\"image\"], response))\n",
        "      if num_examples and len(outputs) >= num_examples:\n",
        "        return outputs"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bf6ad946",
      "metadata": {
        "id": "n9r9V1jwJvu9"
      },
      "source": [
        "## Tune the model\n",
        "\n",
        "Now that you've set everything up and taken a look at the training data, it's time to finally tune the model. The code below runs the training loop for 64 steps and prints the learning rate (`lr` in the printed output) and the loss for each step.\n",
        "\n",
        "Every 16 steps, the code prints the model's predictions at that point in training. It uses the same set of images each time so that you can see the model's ability to describe them improve over time.\n",
        "\n",
        "At earlier steps in the training, the descriptions are likely to have issues, such as sentences that repeat as the model gets stuck in a loop, or unfinished sentences. The model's predictions become steadily more accurate as training progresses. By step 64, the model's predictions should closely resemble the descriptions provided by the training data.\n",
        "\n",
        "This process takes around 15 minutes to complete on a T4 GPU."
      ]
    },
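    {
      "cell_type": "markdown",
      "id": "e08a2c57",
      "metadata": {},
      "source": [
        "The `lr` column in the log below rises linearly over the first six steps to 0.03 and then decays smoothly. A common schedule with this shape is linear warmup followed by cosine decay; the sketch below assumes that schedule, with `total_steps=64`, `base_lr=0.03`, and `warmup_steps=6` read off the printed log. It is a hypothetical illustration (`warmup_cosine_lr` is not a `big_vision` function), so its values after warmup may differ slightly from the ones the notebook prints.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def warmup_cosine_lr(step, total_steps=64, base_lr=0.03, warmup_steps=6):\n",
        "  # step is 1-indexed, like the printed training log.\n",
        "  if step <= warmup_steps:\n",
        "    return base_lr * step / warmup_steps  # Linear warmup.\n",
        "  progress = (step - warmup_steps) / (total_steps - warmup_steps)\n",
        "  return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))  # Cosine decay.\n",
        "\n",
        "for step in (1, 6, 7, 32, 64):\n",
        "  print(f'step: {step:2d}/64   lr: {warmup_cosine_lr(step):.5f}')\n",
        "```"
      ]
    },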
    {
      "cell_type": "code",
      "execution_count": 11,
      "id": "f78b3fea",
      "metadata": {
        "id": "067wj_6bZAG3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "step:  1/64   lr: 0.00500   loss: 3.2539\n",
            "step:  2/64   lr: 0.01000   loss: 1.9291\n",
            "step:  3/64   lr: 0.01500   loss: 1.5984\n",
            "step:  4/64   lr: 0.02000   loss: 1.6361\n",
            "step:  5/64   lr: 0.02500   loss: 2.0249\n",
            "step:  6/64   lr: 0.03000   loss: 2.6033\n",
            "step:  7/64   lr: 0.02998   loss: 1.9704\n",
            "step:  8/64   lr: 0.02992   loss: 1.6470\n",
            "step:  9/64   lr: 0.02981   loss: 1.5255\n",
            "step: 10/64   lr: 0.02966   loss: 1.5204\n",
            "step: 11/64   lr: 0.02947   loss: 1.3989\n",
            "step: 12/64   lr: 0.02924   loss: 1.2505\n",
            "step: 13/64   lr: 0.02897   loss: 1.1247\n",
            "step: 14/64   lr: 0.02866   loss: 1.0750\n",
            "step: 15/64   lr: 0.02831   loss: 1.2703\n",
            "step: 16/64   lr: 0.02792   loss: 1.0917\n",
            "Model predictions at step 16\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman&#x27;s hand rests on a white wall, casting a shadow on the wall. the dress is pink, and the sleeves are long. the hand is on the wall, and the shadow is on the wall. the dress is flowing, and the sleeves are gathered. the wall is white, and the shadow is long. the hand is on the wall, and the shadow is long. the dress is pink, and the sleeves are gathered. the shadow is long.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman in a white dress with a pink flower on it sits on a stone wall overlooking the ocean. the dress has a floral pattern and a white bag on her hand. the sky is blue and the water is calm. the boat is on the water and the sails are white. the dress is flowing in the wind. the woman is wearing a hat and holding a bag. the dress is long and the flowers are pink. the sky is clear and the water is calm. the boat is on the water and the sails are white. the dress is flowing in the wind. the woman is sitting on a stone wall. the bag is</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a red blazer and black pants, with a black bag on their hip. the bag has a silver zipper and a white writing on it. the person is wearing a white top underneath the blazer. the bag is black and has a silver zipper. the jacket is red and has a silver button. the pants are black and have a silver zipper. the person is wearing a white top underneath the blazer. the bag is black and has a silver zipper. the jacket is red and has a silver button. the pants are black and have a silver zipper. the bag is black and has a silver zipper. the person is wearing a</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman in a pink shirt and blue jeans stands on a stone staircase. the jeans have a hole in the knee. the woman is wearing a white cardigan and a pink bag. the bag is on her arm. the steps are gray. the wall is gray. the sky is blue. the ground is gray. the woman is wearing a pink bag. the sky is blue. the ground is gray. the wall is gray. the steps are gray. the sky is blue. the ground is gray. the woman is wearing a pink bag. the bag is on her arm. the steps are gray. the wall is gray. the sky</p>\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "step: 17/64   lr: 0.02750   loss: 1.1208\n",
            "step: 18/64   lr: 0.02704   loss: 1.2137\n",
            "step: 19/64   lr: 0.02655   loss: 1.0639\n",
            "step: 20/64   lr: 0.02602   loss: 1.0356\n",
            "step: 21/64   lr: 0.02546   loss: 0.9214\n",
            "step: 22/64   lr: 0.02488   loss: 1.0569\n",
            "step: 23/64   lr: 0.02426   loss: 0.9526\n",
            "step: 24/64   lr: 0.02362   loss: 0.6038\n",
            "step: 25/64   lr: 0.02296   loss: 0.8039\n",
            "step: 26/64   lr: 0.02227   loss: 0.7570\n",
            "step: 27/64   lr: 0.02156   loss: 0.7252\n",
            "step: 28/64   lr: 0.02083   loss: 0.7221\n",
            "step: 29/64   lr: 0.02009   loss: 0.7316\n",
            "step: 30/64   lr: 0.01933   loss: 0.7288\n",
            "step: 31/64   lr: 0.01856   loss: 0.6435\n",
            "step: 32/64   lr: 0.01778   loss: 0.7477\n",
            "Model predictions at step 32\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman in a pink dress stands on a white wall, her hand on the wall. the dress is pink, and the woman&#x27;s hand is on the wall. the dress is long and flowing, and the woman&#x27;s hand is gripping the wall. the woman is wearing a bracelet and a watch. the dress is pink, and the woman&#x27;s hand is on the wall. the woman is standing on a white wall, and the wall is white. the woman&#x27;s hand is gripping the wall, and her fingers are curled. the woman is wearing a bracelet and a watch. the dress is pink, and the woman&#x27;</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman in a white dress with a floral pattern stands on a stone wall overlooking the ocean. the dress is long and flowing, and the woman is wearing a straw bag. the sky is clear and blue, and the water is calm. the woman is standing on a stone wall, and the dress is flowing in the wind. the woman is holding a white bag and wearing a pair of sandals. the dress is white and has a floral pattern. the woman is standing on a stone wall, and the dress is flowing in the wind. the woman is wearing a straw bag and a pair of sandals. the dress is long and has a</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wears a red blazer with a black fanny pack on their hip. the blazer is open and the person is wearing black pants. the person is standing in front of a green plant and is holding their hand in their pocket. the bag is black and has a zipper. the person is wearing a black top underneath the jacket. the jacket is red and has a button on the front. the person is wearing a black belt and a black fanny pack. the jacket is open and the person is wearing a black pants. the bag is on the person&#x27;s hip and the zipper is on the bag. the person is standing in</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman stands on a stone staircase, her hand on her bag. her jeans are blue, and her shirt is pink. the woman is wearing a white cardigan and a pink bag. the stairs are made of stone, and the wall is made of concrete. the woman is standing on the stairs, and her hand is on her bag. the bag is pink, and the strap is long. the woman is wearing a bracelet and a necklace. the jeans are blue, and the buttons are white. the woman is wearing a pink shirt and a white cardigan. the bag is on her arm, and her hand is on the bag. the</p>\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "step: 33/64   lr: 0.01699   loss: 0.6949\n",
            "step: 34/64   lr: 0.01620   loss: 0.6263\n",
            "step: 35/64   lr: 0.01540   loss: 0.3855\n",
            "step: 36/64   lr: 0.01460   loss: 0.2839\n",
            "step: 37/64   lr: 0.01380   loss: 0.3310\n",
            "step: 38/64   lr: 0.01301   loss: 0.4091\n",
            "step: 39/64   lr: 0.01222   loss: 0.4324\n",
            "step: 40/64   lr: 0.01144   loss: 0.3957\n",
            "step: 41/64   lr: 0.01067   loss: 0.3261\n",
            "step: 42/64   lr: 0.00991   loss: 0.4206\n",
            "step: 43/64   lr: 0.00917   loss: 0.4413\n",
            "step: 44/64   lr: 0.00844   loss: 0.3780\n",
            "step: 45/64   lr: 0.00773   loss: 0.3321\n",
            "step: 46/64   lr: 0.00704   loss: 0.2110\n",
            "step: 47/64   lr: 0.00638   loss: 0.1994\n",
            "step: 48/64   lr: 0.00574   loss: 0.1646\n",
            "Model predictions at step 48\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman in a pink dress stands on a white wall, her hand on the wall. the dress is pink, and the sleeves are long and gathered. the woman&#x27;s hand is gripping the wall. the wall is white, and the shadow on the wall is long and dark.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman stands on a pier, her dress flowing in the wind. the sky is clear and blue, with a few clouds. the water is calm and blue, with a few waves. the woman holds her bag and stands with her legs crossed. the dress is white with a red and black flower print. the woman wears short sleeves and a tie on the dress. the dress is long and flowing.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wears a red blazer with a black belt and bag. the blazer is open and the bag is strapped to their waist. the bag is black and has a zipper. the person&#x27;s hand is on the bag. the bag has a zipper and a silver chain. the blazer is loose and the buttons are unbuttoned. the person is standing next to a green plant.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman stands on a stone staircase, her hand on her bag. the jeans are blue, and the fabric is ripped. the shirt is pink, and the buttons are white. the woman is wearing a white cardigan and a silver bracelet on her wrist. the bag is pink, and the strap is pink.</p>\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "step: 49/64   lr: 0.00512   loss: 0.1421\n",
            "step: 50/64   lr: 0.00454   loss: 0.2420\n",
            "step: 51/64   lr: 0.00398   loss: 0.1420\n",
            "step: 52/64   lr: 0.00345   loss: 0.1434\n",
            "step: 53/64   lr: 0.00296   loss: 0.1580\n",
            "step: 54/64   lr: 0.00250   loss: 0.2400\n",
            "step: 55/64   lr: 0.00208   loss: 0.1307\n",
            "step: 56/64   lr: 0.00169   loss: 0.1296\n",
            "step: 57/64   lr: 0.00134   loss: 0.1500\n",
            "step: 58/64   lr: 0.00103   loss: 0.1329\n",
            "step: 59/64   lr: 0.00076   loss: 0.0738\n",
            "step: 60/64   lr: 0.00053   loss: 0.1207\n",
            "step: 61/64   lr: 0.00034   loss: 0.1089\n",
            "step: 62/64   lr: 0.00019   loss: 0.1033\n",
            "step: 63/64   lr: 0.00008   loss: 0.1217\n",
            "step: 64/64   lr: 0.00002   loss: 0.1000\n",
            "Model predictions at step 64\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman in a pink dress stands on a white staircase, her hand on the wall. the dress is pink, and the fabric is sheer. the woman&#x27;s hand is gripping the wall. the stairs are white, and the wall is painted white. the woman is wearing long sleeves, and the sleeves are gathered at the wrist. the dress has a collar, and the collar is white. the woman is standing on a step, and her hand is on the wall.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman stands on a pier, her dress flowing in the wind. the sky is clear and blue, with a few fluffy clouds. the water is calm and blue, and the boats on the water are visible. the woman&#x27;s hand is on her hip, and her other hand is on her dress. the woman is wearing a long white dress with a floral pattern, and her hair is blonde. the dress is flowing in the wind, and the flowers on the dress are red and pink. the woman is standing next to the ocean, and the boats are floating on the water.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wears a red blazer with a black belt bag. the bag has a zipper and a silver zipper pull. the person wears black pants and has their fingers in the bag. the blazer has a button and a single vent. the person stands in front of a green plant.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman stands on a stone staircase, her hand on her purse. the jeans are blue, and the fabric is torn. the shirt is pink, and the buttons are white. the woman is wearing a white cardigan and a silver bracelet on her wrist. the bag is pink, and the strap is pink. the woman is walking on the street.</p>\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Run a short training loop with a cosine learning-rate schedule and warmup.\n",
        "#\n",
        "# Note: the first step can be slow on some machines (up to several minutes)\n",
        "# because XLA compiles the jax.jit-wrapped function on its first call.\n",
        "#\n",
        "\n",
        "BATCH_SIZE = 8\n",
        "TRAIN_EXAMPLES = 512\n",
        "LEARNING_RATE = 0.03\n",
        "\n",
        "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n",
        "EVAL_STEPS = TRAIN_STEPS // 4\n",
        "\n",
        "train_data_it = train_data_iterator()\n",
        "\n",
        "sched_fn = big_vision.utils.create_learning_rate_schedule(\n",
        "    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,\n",
        "    decay_type=\"cosine\", warmup_percent=0.10)\n",
        "\n",
        "for step in range(1, TRAIN_STEPS+1):\n",
        "  # Collect BATCH_SIZE training examples into a list.\n",
        "  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]\n",
        "\n",
        "  # Stack the list of example dicts into one dict of np.arrays, then shard it across devices.\n",
        "  batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n",
        "  batch = big_vision.utils.reshard(batch, data_sharding)\n",
        "\n",
        "  # Run one training step and report the training loss.\n",
        "  learning_rate = sched_fn(step)\n",
        "  params, loss = update_fn(params, batch, learning_rate)\n",
        "\n",
        "  loss = jax.device_get(loss)\n",
        "  print(f\"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}\")\n",
        "\n",
        "  if (step % EVAL_STEPS) == 0:\n",
        "    print(f\"Model predictions at step {step}\")\n",
        "    html_out = \"\"\n",
        "    for image, caption in make_predictions(\n",
        "        validation_data_iterator(), num_examples=4, batch_size=4):\n",
        "      html_out += render_example(image, caption)\n",
        "    display(HTML(html_out))\n"
      ]
    },
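    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `lr` column in the training log above comes from `big_vision.utils.create_learning_rate_schedule`. The following standalone sketch reproduces that column in plain Python. It assumes the schedule uses `int(warmup_percent * total_steps)` warmup steps, a linear warmup ramp, and a half-cosine decay over the remaining steps; the function name `cosine_lr_with_warmup` is hypothetical and is not part of big_vision.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "BATCH_SIZE = 8\n",
        "TRAIN_EXAMPLES = 512\n",
        "LEARNING_RATE = 0.03\n",
        "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE  # 64 optimizer steps\n",
        "EVAL_STEPS = TRAIN_STEPS // 4  # predictions shown at steps 16, 32, 48, 64\n",
        "\n",
        "\n",
        "def cosine_lr_with_warmup(step, total_steps, base, warmup_percent):\n",
        "  \"\"\"Hypothetical re-implementation of a cosine schedule with linear warmup.\"\"\"\n",
        "  warmup_steps = int(warmup_percent * total_steps)  # assumed rounding rule\n",
        "  # Fraction of the post-warmup decay phase completed, clipped to [0, 1].\n",
        "  progress = (step - warmup_steps) / float(total_steps - warmup_steps)\n",
        "  progress = min(max(progress, 0.0), 1.0)\n",
        "  lr = base * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
        "  # Linear warmup: scale up from 0 to the full rate at step == warmup_steps.\n",
        "  if warmup_steps:\n",
        "    lr *= min(1.0, step / warmup_steps)\n",
        "  return lr\n",
        "\n",
        "\n",
        "# Same arguments as the notebook: total_steps=TRAIN_STEPS+1, warmup_percent=0.10.\n",
        "for step in (17, 32, 48, 64):\n",
        "  lr = cosine_lr_with_warmup(step, TRAIN_STEPS + 1, LEARNING_RATE, 0.10)\n",
        "  print(f\"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {lr:.5f}\")\n",
        "```\n",
        "\n",
        "Under these assumptions the sketch prints `0.02750`, `0.01778`, `0.00574`, and `0.00002` for steps 17, 32, 48, and 64, matching the log. Because `total_steps` is `TRAIN_STEPS+1`, the last training step (64) sits just short of the end of the decay, so the final learning rate is small but not exactly zero."
      ]
    },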
    {
      "cell_type": "markdown",
      "id": "f6019d20",
      "metadata": {
        "id": "glScsFLVJ52c"
      },
      "source": [
        "## Output\n",
        "\n",
        "The validation data for this notebook consists of just 10 images. In practice you would typically validate on many more examples, but 10 are enough to inspect the tuned model's output here. After tuning, the generated descriptions should be very similar in form and content coverage to the training-data descriptions you looked at earlier in this notebook.\n",
        "\n",
        "Run the following code to generate descriptions for all 10 validation images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "6c3b2164",
      "metadata": {
        "id": "hgUhEKjzPdMQ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model predictions\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman in a pink dress stands on a white staircase, her hand on the wall. the dress is pink, and the fabric is sheer. the woman&#x27;s hand is gripping the wall. the stairs are white, and the wall is painted white. the woman is wearing long sleeves, and the sleeves are gathered at the wrist. the dress has a collar, and the collar is white. the woman is standing on a step, and her hand is on the wall.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman stands on a pier, her dress flowing in the wind. the sky is clear and blue, with a few fluffy clouds. the water is calm and blue, and the boats on the water are visible. the woman&#x27;s hand is on her hip, and her other hand is on her dress. the woman is wearing a long white dress with a floral pattern, and her hair is blonde. the dress is flowing in the wind, and the flowers on the dress are red and pink. the woman is standing next to the ocean, and the boats are floating on the water.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wears a red blazer with a black belt bag. the bag has a zipper and a silver zipper pull. the person wears black pants and has their fingers in the bag. the blazer has a button and a single vent. the person stands in front of a green plant.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman stands on a stone staircase, her hand on her purse. the jeans are blue, and the fabric is torn. the shirt is pink, and the buttons are white. the woman is wearing a white cardigan and a silver bracelet on her wrist. the bag is pink, and the strap is pink. the woman is walking on the street.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman is lying on a bed, wearing a pink sweater with the words &quot;love will save us&quot; written on it. the sweater is long-sleeved and has a crew neckline. the woman is wearing white sneakers and has her hand on the bed. the jeans are blue and have a belt loop. the blanket is gray and fuzzy.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a man stands with his hand on his head, his long blonde hair flowing in the wind. the man wears a black sweater and a white and black plaid shirt. the sweater is navy blue and the shirt is white and black. the man&#x27;s hair is messy and his eyes are closed. the man is standing against a pink wall.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a row of white hangers on a white clothes rack, with a white wall in the background. the hangers are white, and the metal bar on the rack is white. the rack has a white metal pole on the bottom, and a white metal bar on the top. the wall is white, and the light is shining on the rack.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a white sweater hangs on a wooden hanger, with a white drawstring on the bottom of the sweater. the sweater has a hood and a pocket on the front. the pants have a white drawstring on the bottom of the pants. the clothes are hanging on a black pole, with a black circle on the wall. the clothes are on a white rack, with a white tag on the hanger.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman stands on a sidewalk, showcasing her black knee-high boots and black bag. the boots are made of suede and have a low heel. the bag has a gold chain strap and a silver lock. the woman&#x27;s hand is on the bag. the jeans are blue and have a slight stretch. the woman is wearing a black knee-high boot and a black long-sleeve shirt. the boots are black and have a low heel. the bag is black and has a gold chain strap. the woman is standing on a gray sidewalk.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a man stands on a road, his hands in his pockets. his pants are brown, and his shirt is white. he wears a denim jacket and white shoes. the road is gray and the trees are green. the man&#x27;s hands are in his pockets. the man is standing on the road, his back to the camera. the man is wearing a white t-shirt and a blue denim jacket. the man&#x27;s shoes are white. the man&#x27;s pants are brown. the man is smiling.</p>\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# The validation data consists of 10 images in a different domain than training\n",
        "# data.\n",
        "\n",
        "print(\"Model predictions\")\n",
        "html_out = \"\"\n",
        "for image, caption in make_predictions(validation_data_iterator(), batch_size=4):\n",
        "  html_out += render_example(image, caption)\n",
        "display(HTML(html_out))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "[PaliGemma_2]Finetune_with_JAX.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
